# -*- coding:utf-8 -*-

import tensorflow as tf
from element.rnn_module import get_multi_lstm_cell
from tensorflow.python.ops.rnn import dynamic_rnn
from config.glob.global_pool import global_pool

"""
DRNN
"""


def net(self):
    """Build the multi-layer LSTM (DRNN) prediction graph.

    Reads hyper-parameters from ``global_pool.config``, consumes
    ``self.xs`` (integer token ids), and exposes on ``self``:

    - ``y_pred``: raw logits of shape ``[batch * steps, vocab_size]``
    - ``proba_prediction``: softmax probabilities over the vocabulary
    - ``initial_state`` / ``final_state``: RNN states for stateful feeding
    """
    # Pull configuration values.
    batch_size = global_pool.config.batch_size
    # num_steps = global_pool.config.xs_shape[0]
    cls_num = global_pool.embedding.vocab_size  # number of classes == vocabulary size
    state_size = global_pool.config.net.state_size  # hidden-state units per LSTM layer
    layers_num = global_pool.config.net.rnn_layers_num  # number of stacked LSTM layers

    # TODO(review): confirm config.embed.size is set whenever embed.use is True.
    if global_pool.config.embed.use:
        # Use a trainable dense embedding (needed for large vocabularies,
        # e.g. Chinese characters/words).
        # NOTE(review): tf.get_variable ignores tf.name_scope, so this variable is
        # created at the root graph scope as 'embedding'; switch to
        # tf.variable_scope if it should live under 'embedding/' — but that
        # would rename the variable in existing checkpoints. Confirm intent.
        with tf.name_scope('embedding'), tf.device("/cpu:0"):
            embed = tf.get_variable('embedding', [cls_num, global_pool.config.embed.size])
            lstm_inputs = tf.nn.embedding_lookup(embed, self.xs)  # look up rows for each id in self.xs
    else:
        # Small vocabularies (e.g. English letters) can be fed as one-hot
        # vectors directly, with no embedding layer.
        lstm_inputs = tf.one_hot(self.xs, cls_num)

    with tf.name_scope('lstm'):
        # Stack `layers_num` LSTM layers (each wrapped with dropout keep_prob
        # inside get_multi_lstm_cell — presumably; verify the helper).
        cell = tf.nn.rnn_cell.MultiRNNCell(
            [get_multi_lstm_cell(state_size, global_pool.config.net.keep_prob) for _ in range(layers_num)]
        )
        initial_state = cell.zero_state(batch_size, tf.float32)
        lstm_outputs, final_state = dynamic_rnn(cell, lstm_inputs, initial_state=initial_state)
        # dynamic_rnn returns a single [batch, steps, state_size] tensor, so this
        # concat is effectively an identity; kept for graph compatibility.
        seq_output = tf.concat(lstm_outputs, 1)
        # Project every timestep onto the vocabulary via a softmax layer.
        x = tf.reshape(seq_output, [-1, state_size])
        with tf.variable_scope('softmax'):
            softmax_w = tf.Variable(tf.truncated_normal([state_size, cls_num], stddev=0.1))
            softmax_b = tf.Variable(tf.zeros(cls_num))
        logits = tf.matmul(x, softmax_w) + softmax_b
        proba_prediction = tf.nn.softmax(logits, name='predictions')
        self.y_pred = logits
        self.proba_prediction = proba_prediction
        self.initial_state = initial_state
        self.final_state = final_state
